import numpy as np
import torch
# Find the position of the largest element in a small 1-D tensor.
a = [0.1, 0.2, 0.3, 0.4]
aa = torch.tensor(a)  # 1-D tensor of shape (4,), default floating dtype
# torch.argmax returns just the index, which is all the original
# torch.max(aa.data, 0)[1] kept. Reading through `.data` is a legacy pattern
# that silently detaches from autograd and is unnecessary here.
res = torch.argmax(aa, dim=0)
print(res)
# pred = np.array([1,2],dtype=int)
# pred = np.append(pred, res)
# print(pred)